{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l2hyuOxRwabt"
   },
   "source": [
    "# Ungraded Lab: Knowledge Distillation\n",
    "------------------------\n",
    " \n",
    "Welcome, during this ungraded lab you are going to perform a model compression technique known as **knowledge distillation** in which a `student` model \"learns\" from a more complex model known as the `teacher`. In particular you will:\n",
    "\n",
    "\n",
    "1. Define a `Distiller` class with the custom logic for the distillation process.\n",
    "2. Train the `teacher` model which is a CNN that implements regularization via dropout.\n",
    "3. Train a `student` model (a smaller version of the teacher without regularization) by using knowledge distillation.\n",
    "4. Train another `student` model from scratch without distillation called `student_scratch`.\n",
    "5. Compare the three students.\n",
    "\n",
    "\n",
    "This notebook is based on [this](https://keras.io/examples/vision/knowledge_distillation/) official Keras tutorial. \n",
    "\n",
    "If you want a more theoretical approach to this topic be sure to check this paper [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531). \n",
    "\n",
    "Let's get started!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qAhJX9iLwabu"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SosaPG6jwabv"
   },
   "outputs": [],
   "source": [
    "# For setting random seeds\n",
    "import os\n",
    "os.environ['PYTHONHASHSEED']=str(42)\n",
    "\n",
    "# Libraries\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "# More random seed setup\n",
    "tf.random.set_seed(42)\n",
    "np.random.seed(42)\n",
    "random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8MsH7h6tqC2i"
   },
   "source": [
    "## Prepare the data\n",
    "\n",
    "For this lab you will use the [cats vs dogs](https://www.tensorflow.org/datasets/catalog/cats_vs_dogs) which is composed of many images of cats and dogs alongise their respective labels. \n",
    "\n",
    "Begin by downloading the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WGWF89iLwab0"
   },
   "outputs": [],
   "source": [
    "# Define train/test splits\n",
    "splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']\n",
    "\n",
    "# Download the dataset\n",
    "(train_examples, validation_examples, test_examples), info = tfds.load('cats_vs_dogs', with_info=True, as_supervised=True, split=splits)\n",
    "\n",
    "# Print useful information\n",
    "num_examples = info.splits['train'].num_examples\n",
    "num_classes = info.features['label'].num_classes\n",
    "\n",
    "print(f\"There are {num_examples} images for {num_classes} classes.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5LIucSJ8rKAG"
   },
   "source": [
    "Preprocess the data for training by normalizing pixel values, reshaping them and creating batches of data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cKhoLUfIR81q"
   },
   "outputs": [],
   "source": [
    "# Some global variables\n",
    "pixels = 224\n",
    "IMAGE_SIZE = (pixels, pixels)\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# Apply resizing and pixel normalization\n",
    "def format_image(image, label):\n",
    "    image = tf.image.resize(image, IMAGE_SIZE) / 255.0\n",
    "    return  image, label\n",
    "\n",
    "# Create batches of data\n",
    "train_batches = train_examples.shuffle(num_examples // 64).map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
    "validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
    "test_batches = test_examples.map(format_image).batch(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lb5TgrJbZjJR"
   },
   "source": [
    "## Code the custom `Distiller` model\n",
    "\n",
    "In order to implement the distillation process you will create a custom Keras model which you will name `Distiller`. In order to do this you need to override some of the vanilla methods of a `keras.Model` to include the custom logic for the knowledge distillation. You need to override these methods:\n",
    "- `compile`: This model needs some extra parameters to be compiled such as the teacher and student losses, the alpha and the temperature.\n",
    "- `train_step`: Controls how the model is trained. This will be where the actual knowledge distillation logic will be found. This method is what is called when you do `model.fit`.\n",
    "- `test_step`: Controls the evaluation of the model. This method is what is called when you do `model.evaluate`.\n",
    "\n",
    "To learn more about customizing models check out the [official docs](https://keras.io/guides/customizing_what_happens_in_fit/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EdZ7JiqEwabw"
   },
   "outputs": [],
   "source": [
    "class Distiller(keras.Model):\n",
    "\n",
    "  # Needs both the student and teacher models to create an instance of this class\n",
    "  def __init__(self, student, teacher):\n",
    "      super(Distiller, self).__init__()\n",
    "      self.teacher = teacher\n",
    "      self.student = student\n",
    "\n",
    "\n",
    "  # Will be used when calling model.compile()\n",
    "  def compile(self, optimizer, metrics, student_loss_fn,\n",
    "              distillation_loss_fn, alpha, temperature):\n",
    "\n",
    "      # Compile using the optimizer and metrics\n",
    "      super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)\n",
    "      \n",
    "      # Add the other params to the instance\n",
    "      self.student_loss_fn = student_loss_fn\n",
    "      self.distillation_loss_fn = distillation_loss_fn\n",
    "      self.alpha = alpha\n",
    "      self.temperature = temperature\n",
    "\n",
    "\n",
    "  # Will be used when calling model.fit()\n",
    "  def train_step(self, data):\n",
    "      # Data is expected to be a tuple of (features, labels)\n",
    "      x, y = data\n",
    "\n",
    "      # Vanilla forward pass of the teacher\n",
    "      # Note that the teacher is NOT trained\n",
    "      teacher_predictions = self.teacher(x, training=False)\n",
    "\n",
    "      # Use GradientTape to save gradients\n",
    "      with tf.GradientTape() as tape:\n",
    "          # Vanilla forward pass of the student\n",
    "          student_predictions = self.student(x, training=True)\n",
    "\n",
    "          # Compute vanilla student loss\n",
    "          student_loss = self.student_loss_fn(y, student_predictions)\n",
    "          \n",
    "          # Compute distillation loss\n",
    "          # Should be KL divergence between logits softened by a temperature factor\n",
    "          distillation_loss = self.distillation_loss_fn(\n",
    "              tf.nn.softmax(teacher_predictions / self.temperature, axis=1),\n",
    "              tf.nn.softmax(student_predictions / self.temperature, axis=1))\n",
    "\n",
    "          # Compute loss by weighting the two previous losses using the alpha param\n",
    "          loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss\n",
    "\n",
    "      # Use tape to calculate gradients for student\n",
    "      trainable_vars = self.student.trainable_variables\n",
    "      gradients = tape.gradient(loss, trainable_vars)\n",
    "\n",
    "      # Update student weights \n",
    "      # Note that this done ONLY for the student\n",
    "      self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
    "\n",
    "      # Update the metrics\n",
    "      self.compiled_metrics.update_state(y, student_predictions)\n",
    "\n",
    "      # Return a performance dictionary\n",
    "      # You will see this being outputted during training\n",
    "      results = {m.name: m.result() for m in self.metrics}\n",
    "      results.update({\"student_loss\": student_loss, \"distillation_loss\": distillation_loss})\n",
    "      return results\n",
    "\n",
    "\n",
    "  # Will be used when calling model.evaluate()\n",
    "  def test_step(self, data):\n",
    "      # Data is expected to be a tuple of (features, labels)\n",
    "      x, y = data\n",
    "\n",
    "      # Use student to make predictions\n",
    "      # Notice that the training param is set to False\n",
    "      y_prediction = self.student(x, training=False)\n",
    "\n",
    "      # Calculate student's vanilla loss\n",
    "      student_loss = self.student_loss_fn(y, y_prediction)\n",
    "\n",
    "      # Update the metrics\n",
    "      self.compiled_metrics.update_state(y, y_prediction)\n",
    "\n",
    "      # Return a performance dictionary\n",
    "      # You will see this being outputted during inference\n",
    "      results = {m.name: m.result() for m in self.metrics}\n",
    "      results.update({\"student_loss\": student_loss})\n",
    "      return results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f1QXmNmisKNG"
   },
   "source": [
    "## Teacher and student models\n",
    "\n",
    "For the models you will use a standard CNN architecture that implements regularization via some dropout layers (in the case of the teacher), but it could be any Keras model. \n",
    "\n",
    "Define the `create_model` functions to create models with the desired architecture using Keras' [Sequential Model](https://keras.io/guides/sequential_model/).\n",
    "\n",
    "Notice that `create_small_model` returns a simplified version of the model (in terms of number of layers and absence of regularization) that `create_big_model` returns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "35GyhKrgwt0o"
   },
   "outputs": [],
   "source": [
    "# Teacher model\n",
    "def create_big_model():\n",
    "  tf.random.set_seed(42)\n",
    "  model = keras.models.Sequential([\n",
    "    keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),\n",
    "    keras.layers.MaxPooling2D((2, 2)),\n",
    "    keras.layers.Conv2D(64, (3, 3), activation='relu'),\n",
    "    keras.layers.MaxPooling2D((2, 2)),\n",
    "    keras.layers.Dropout(0.2),\n",
    "    keras.layers.Conv2D(64, (3, 3), activation='relu'),\n",
    "    keras.layers.MaxPooling2D((2, 2)),\n",
    "    keras.layers.Conv2D(128, (3, 3), activation='relu'),\n",
    "    keras.layers.MaxPooling2D((2, 2)),\n",
    "    keras.layers.Dropout(0.5),\n",
    "    keras.layers.Flatten(),\n",
    "    keras.layers.Dense(512, activation='relu'),\n",
    "    keras.layers.Dense(2)\n",
    "  ])\n",
    "\n",
    "  return model\n",
    "\n",
    "\n",
    "\n",
    "# Student model\n",
    "def create_small_model():\n",
    "  tf.random.set_seed(42)\n",
    "  model = keras.models.Sequential([\n",
    "    keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),\n",
    "    keras.layers.MaxPooling2D((2, 2)),\n",
    "    keras.layers.Flatten(),\n",
    "    keras.layers.Dense(2)\n",
    "  ])\n",
    "\n",
    "  return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8FsetAiyvHlr"
   },
   "source": [
    "There are two important things to notice:\n",
    "- The last layer does not have an softmax activation because the raw logits are needed for the knowledge distillation.\n",
    "- Regularization via dropout layers will be applied to the teacher but NOT to the student. This is because the student should be able to learn this regularization through the distillation process.\n",
    "\n",
    "Remember that the student model can be thought of as a simplified (or compressed) version of the teacher model.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HazdkHp9j7Ur"
   },
   "outputs": [],
   "source": [
    "# Create the teacher\n",
    "teacher = create_big_model()\n",
    "\n",
    "# Plot architecture\n",
    "keras.utils.plot_model(teacher, rankdir=\"LR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Bywn32D7kZ9H"
   },
   "outputs": [],
   "source": [
    "# Create the student\n",
    "student = create_small_model()\n",
    "\n",
    "# Plot architecture\n",
    "keras.utils.plot_model(student, rankdir=\"LR\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3lxJnnI4xs-s"
   },
   "source": [
    "Check the actual difference in number of trainable parameters (weights and biases) between both models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ed8Sd21vvwSK"
   },
   "outputs": [],
   "source": [
    "# Calculates number of trainable params for a given model\n",
    "def num_trainable_params(model):\n",
    "  return np.sum([np.prod(v.get_shape()) for v in model.trainable_weights])\n",
    "\n",
    "\n",
    "student_params = num_trainable_params(student)\n",
    "teacher_params = num_trainable_params(teacher)\n",
    "\n",
    "print(f\"Teacher model has: {teacher_params} trainable params.\\n\")\n",
    "print(f\"Student model has: {student_params} trainable params.\\n\")\n",
    "print(f\"Teacher model is roughly {teacher_params//student_params} times bigger than the student model.\")"
   ]
  },
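  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of the student's count, here is a sketch of the arithmetic. It assumes the Keras defaults for `Conv2D` (`valid` padding, stride 1) and for `MaxPooling2D` (stride equal to the pool size), so the 3x3 convolution maps the 224x224x3 input to 222x222x32 and the 2x2 pooling reduces it to 111x111x32 before flattening:\n",
    "\n",
    "$$\\underbrace{(3 \\cdot 3 \\cdot 3) \\cdot 32 + 32}_{\\text{Conv2D}} + \\underbrace{(111 \\cdot 111 \\cdot 32) \\cdot 2 + 2}_{\\text{Dense}} = 896 + 788546 = 789442$$\n",
    "\n",
    "Almost all of the student's parameters sit in the final dense layer."
   ]
  },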
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "O_O66k7dwab1"
   },
   "source": [
    "### Train the teacher\n",
    "\n",
    "In knowledge distillation it is assumed that the teacher has already been trained so the natural first step is to train the teacher. You will do so for a total of 8 epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cWtaALBbwab1"
   },
   "outputs": [],
   "source": [
    "# Compile the teacher model\n",
    "teacher.compile(\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # Notice from_logits param is set to True\n",
    "    optimizer=keras.optimizers.Adam(),\n",
    "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]\n",
    ")\n",
    "\n",
    "# Fit the model and save the training history (will take from 5 to 10 minutes depending on the GPU you were assigned to)\n",
    "teacher_history = teacher.fit(train_batches, epochs=8, validation_data=validation_batches)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9kSMig49wab2"
   },
   "source": [
    "## Train a student from scratch for reference\n",
    "\n",
    "In order to assess the effectiveness of the distillation process, train a model that is equivalent to the student but without doing knowledge distillation. Notice that the training is done for only 5 epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BPb3wE2nwab3"
   },
   "outputs": [],
   "source": [
    "# Create student_scratch model with the same characteristics as the original student\n",
    "student_scratch = create_small_model()\n",
    "\n",
    "# Compile it\n",
    "student_scratch.compile(\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=keras.optimizers.Adam(),\n",
    "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]\n",
    ")\n",
    "\n",
    "# Train and evaluate student trained from scratch (will take around 3 mins with GPU enabled)\n",
    "student_scratch_history = student_scratch.fit(train_batches, epochs=5, validation_data=validation_batches)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2BdD9K57wab2"
   },
   "source": [
    "## Knowledge Distillation\n",
    "\n",
    "To perform the knowledge distillation process you will use the custom model you previously coded. To do so, begin by creating an instance of the `Distiller` class and passing in the student and teacher models. Then compile it with the appropiate parameters and train it!\n",
    "\n",
    "The two student models are trained for only 5 epochs unlike the teacher that was trained for 8. This is done to showcase that the knowledge distillation allows for quicker training times as the student learns from an already trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D7EqhGlAwab2"
   },
   "outputs": [],
   "source": [
    "# Create Distiller instance\n",
    "distiller = Distiller(student=student, teacher=teacher)\n",
    "\n",
    "# Compile Distiller model\n",
    "distiller.compile(\n",
    "    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=keras.optimizers.Adam(),\n",
    "    metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
    "    distillation_loss_fn=keras.losses.KLDivergence(),\n",
    "    alpha=0.05,\n",
    "    temperature=5,\n",
    ")\n",
    "\n",
    "# Distill knowledge from teacher to student (will take around 3 mins with GPU enabled)\n",
    "distiller_history = distiller.fit(train_batches, epochs=5, validation_data=validation_batches)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "voTxT0cIxCYx"
   },
   "source": [
    "## Comparing the models\n",
    "\n",
    "To compare the models you can check the `sparse_categorical_accuracy` of each one on the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7O4xXZlhxp92"
   },
   "outputs": [],
   "source": [
    "# Compute accuracies\n",
    "student_scratch_acc = student_scratch.evaluate(test_batches, return_dict=True).get(\"sparse_categorical_accuracy\")\n",
    "distiller_acc = distiller.evaluate(test_batches, return_dict=True).get(\"sparse_categorical_accuracy\")\n",
    "teacher_acc = teacher.evaluate(test_batches, return_dict=True).get(\"sparse_categorical_accuracy\")\n",
    "\n",
    "# Print results\n",
    "print(f\"\\n\\nTeacher achieved a sparse_categorical_accuracy of {teacher_acc*100:.2f}%.\\n\")\n",
    "print(f\"Student with knowledge distillation achieved a sparse_categorical_accuracy of {distiller_acc*100:.2f}%.\\n\")\n",
    "print(f\"Student without knowledge distillation achieved a sparse_categorical_accuracy of {student_scratch_acc*100:.2f}%.\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JTDRmrXWwab3"
   },
   "source": [
    "The teacher model yields a higger accuracy than the two student models. This is expected since it was trained for more epochs while using a bigger architecture.\n",
    "\n",
    "Notice that the student without distillation was outperfomed by the student with knowledge distillation. \n",
    "\n",
    "Since you saved the training history of each model you can create a plot for a better comparison of the two student models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p-m8dvwS92rF"
   },
   "outputs": [],
   "source": [
    "# Get relevant metrics from a history\n",
    "def get_metrics(history):\n",
    "  history = history.history\n",
    "  acc = history['sparse_categorical_accuracy']\n",
    "  val_acc = history['val_sparse_categorical_accuracy']\n",
    "  return acc, val_acc\n",
    "\n",
    "\n",
    "# Plot training and evaluation metrics given a dict of histories\n",
    "def plot_train_eval(history_dict):\n",
    "  \n",
    "  metric_dict = {}\n",
    "\n",
    "  for k, v in history_dict.items():\n",
    "    acc, val_acc= get_metrics(v)\n",
    "    metric_dict[f'{k} training acc'] = acc\n",
    "    metric_dict[f'{k} eval acc'] = val_acc\n",
    "\n",
    "  acc_plot = pd.DataFrame(metric_dict)\n",
    "  \n",
    "  acc_plot = sns.lineplot(data=acc_plot, markers=True)\n",
    "  acc_plot.set_title('training vs evaluation accuracy')\n",
    "  acc_plot.set_xlabel('epoch')\n",
    "  acc_plot.set_ylabel('sparse_categorical_accuracy')\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "# Plot for comparing the two student models\n",
    "plot_train_eval({\n",
    "    \"distilled\": distiller_history,\n",
    "    \"student_scratch\": student_scratch_history,\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tm1VrbjK16n6"
   },
   "source": [
    "This plot is very interesting because it shows that the distilled version outperformed the unmodified one in almost all of the epochs when using the evaluation set. Alongside this, the student without distillation yields a bigger training accuracy, which is a sign that it is overfitting more than the distilled model. **This hints that the distilled model was able to learn from the regularization that the teacher implemented!** Pretty cool, right?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SGDr0PoC1nuP"
   },
   "source": [
    "-----------------------------\n",
    "**Congratulations on finishing this ungraded lab!** Now you should have a clearer understanding of what Knowledge Distillation is and how it can be implemented using Tensorflow and Keras. \n",
    "\n",
    "This process is widely used for model compression and has proven to perform really well. In fact you might have heard about [`DistilBert`](https://huggingface.co/transformers/model_doc/distilbert.html), which is a smaller, faster, cheaper and lighter of BERT.\n",
    "\n",
    "\n",
    "**Keep it up!**"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
